{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"top\"></a><img src=\"images/chisel_1024.png\" alt=\"Chisel logo\" style=\"width:480px;\" />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Module 4.2: FIRRTL AST Traversal\n",
    "\n",
    "**Prev: [Introduction to FIRRTL](4.1_firrtl_ast.ipynb)**<br>\n",
    "**Next: [Common Pass Idioms](4.3_firrtl_common_idioms.ipynb)**\n",
    "\n",
    "### Understanding IR node children\n",
    "\n",
    "Writing a Firrtl pass usually requires writing functions that walk the Firrtl data structure to either collect information or replace IR nodes with new IR nodes.\n",
    "\n",
    "The IR data structure is a tree: each IR node can have some number of child nodes (which in turn can have their own children, and so on). IR nodes without children are called leaves.\n",
    "\n",
    "Different IR nodes can have different kinds of children. The following table shows the possible child types for each IR node type:\n",
    "\n",
    "```\n",
    "+------------+-----------------------------+\n",
    "|    Node    |          Children           |\n",
    "+------------+-----------------------------+\n",
    "| Circuit    | DefModule                   |\n",
    "| DefModule  | Port, Statement             |\n",
    "| Port       | Type, Direction             |\n",
    "| Statement  | Statement, Expression, Type |\n",
    "| Expression | Expression, Type            |\n",
    "| Type       | Type, Width                 |\n",
    "| Width      |                             |\n",
    "| Direction  |                             |\n",
    "+------------+-----------------------------+\n",
    "```\n",
    "\n",
    "### The map function\n",
    "\n",
    "To write a function that traverses a `Circuit`, we need to first understand the functional programming concept `map`.\n",
    "\n",
    "#### Understanding Seq.map\n",
    "A Scala sequence of strings can be represented as a tree with a root node `Seq` and child nodes `\"a\"`, `\"b\"`, and `\"c\"`:\n",
    "```scala\n",
    "val s = Seq(\"a\", \"b\", \"c\")\n",
    "```\n",
    "```\n",
    "    Seq\n",
    " /   |   \\\n",
    "\"a\" \"b\" \"c\"\n",
    "```\n",
    "\n",
    "Suppose we define a function `f` that, given a String argument `x`, concatenates `x` with itself:\n",
    "```scala\n",
    "def f(x: String): String = x + x\n",
    "```\n",
    "\n",
    "We can call `s.map` to return a new `Seq[String]` whose children are the result of applying `f` to every child of `s`:\n",
    "```scala\n",
    "val s = Seq(\"a\", \"b\", \"c\")\n",
    "def f(x: String): String = x + x  // repeated declaration for clarity\n",
    "val t = s.map(f)\n",
    "println(t) // prints List(aa, bb, cc), since Seq defaults to List\n",
    "```\n",
    "```\n",
    "     Seq\n",
    " /    |    \\\n",
    "\"aa\" \"bb\" \"cc\"\n",
    "```\n",
    "\n",
    "#### Understanding Firrtl's map\n",
    "\n",
    "We use this \"mapping\" idea to define custom `map` methods on IR nodes. Suppose we have a `DoPrim` expression representing 1 + 1; it can be depicted as a tree of expressions with a root node `DoPrim`:\n",
    "```\n",
    "        DoPrim\n",
    "     /          \\\n",
    "UIntValue    UIntValue\n",
    "```\n",
    "\n",
    "If we have a function `f` that takes an `Expression` argument and returns a new `Expression`, we can \"map\" it onto all `Expression` children of a given IR node, such as our `DoPrim`. This returns the following new `DoPrim`, whose children are the result of applying `f` to every `Expression` child of the original `DoPrim`:\n",
    "```\n",
    "        DoPrim\n",
    "     /          \\\n",
    "f(UIntValue)    f(UIntValue)\n",
    "```\n",
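    "\n",
    "To make the idea concrete without depending on a particular Firrtl version, here is a minimal sketch on a toy expression IR. The classes `Expression`, `UIntValue`, and `DoPrim` below are hypothetical stand-ins that only mimic the shape of Firrtl's nodes, `mapExpr` is our own name for the `map` described above, and the example `f` simply doubles every literal it is given.\n",
    "\n",
    "```scala\n",
    "// Toy model only: hypothetical stand-ins, not Firrtl's real IR classes\n",
    "sealed trait Expression {\n",
    "  // Apply f to every direct Expression child of this node\n",
    "  def mapExpr(f: Expression => Expression): Expression\n",
    "}\n",
    "case class UIntValue(value: Int) extends Expression {\n",
    "  def mapExpr(f: Expression => Expression): Expression = this // leaf: no children\n",
    "}\n",
    "case class DoPrim(op: String, args: Seq[Expression]) extends Expression {\n",
    "  def mapExpr(f: Expression => Expression): Expression = DoPrim(op, args.map(f))\n",
    "}\n",
    "\n",
    "object DoPrimMapDemo {\n",
    "  val onePlusOne = DoPrim(\"add\", Seq(UIntValue(1), UIntValue(1)))\n",
    "\n",
    "  // Example f: double every literal, leave other expressions unchanged\n",
    "  def f(e: Expression): Expression = e match {\n",
    "    case UIntValue(v) => UIntValue(v * 2)\n",
    "    case other        => other\n",
    "  }\n",
    "\n",
    "  val mapped = onePlusOne.mapExpr(f)\n",
    "  // mapped == DoPrim(\"add\", Seq(UIntValue(2), UIntValue(2)))\n",
    "}\n",
    "```\n",
    "\n",
    "Note that `mapExpr` only touches the direct children of `DoPrim`; recursing into deeper levels of the tree has to be written explicitly, as in the traversals that follow.\n",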
    "\n",
    "Sometimes IR nodes have children of multiple types. For example, `Conditionally` has both `Expression` and `Statement` children. In this case, `map` applies its function only to the children whose type matches the function's argument type (and return type):\n",
    "```scala\n",
    "val c = Conditionally(info, e, s1, s2) // e: Expression, s1, s2: Statement, info: FileInfo\n",
    "def fExp(e: Expression): Expression = ...\n",
    "def fStmt(s: Statement): Statement = ...\n",
    "c.map(fExp)  // Conditionally(info, fExp(e), s1, s2)\n",
    "c.map(fStmt) // Conditionally(info, e, fStmt(s1), fStmt(s2))\n",
    "```\n",
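    "\n",
    "A minimal sketch of this type-directed behavior on a toy IR follows. The classes are hypothetical and simplified (for example, the toy `Conditionally` has no `info` field). One design detail: Scala cannot overload a single `map` method on both `Expression => Expression` and `Statement => Statement`, because both erase to `Function1`, so the toy uses two separate methods, `mapExpr` and `mapStmt`. Firrtl's own IR classes appear to take a similar approach (methods such as `mapExpr` and `mapStmt`, with `firrtl.Mappers._` supplying the `map` spelling through implicits), though details may vary between versions.\n",
    "\n",
    "```scala\n",
    "// Toy model only: hypothetical, simplified classes (no info field)\n",
    "sealed trait Expression\n",
    "case class Ref(name: String) extends Expression\n",
    "\n",
    "sealed trait Statement {\n",
    "  def mapExpr(f: Expression => Expression): Statement // direct Expression children\n",
    "  def mapStmt(f: Statement => Statement): Statement   // direct Statement children\n",
    "}\n",
    "case class Skip() extends Statement {\n",
    "  def mapExpr(f: Expression => Expression): Statement = this\n",
    "  def mapStmt(f: Statement => Statement): Statement = this\n",
    "}\n",
    "case class Connect(loc: String, expr: Expression) extends Statement {\n",
    "  def mapExpr(f: Expression => Expression): Statement = Connect(loc, f(expr))\n",
    "  def mapStmt(f: Statement => Statement): Statement = this\n",
    "}\n",
    "case class Conditionally(pred: Expression, conseq: Statement, alt: Statement) extends Statement {\n",
    "  def mapExpr(f: Expression => Expression): Statement = Conditionally(f(pred), conseq, alt)\n",
    "  def mapStmt(f: Statement => Statement): Statement = Conditionally(pred, f(conseq), f(alt))\n",
    "}\n",
    "\n",
    "object ConditionallyMapDemo {\n",
    "  val c = Conditionally(Ref(\"en\"), Connect(\"x\", Ref(\"a\")), Skip())\n",
    "\n",
    "  def fExp(e: Expression): Expression = e match {\n",
    "    case Ref(n) => Ref(n + \"_renamed\")\n",
    "  }\n",
    "  def fStmt(s: Statement): Statement = s match {\n",
    "    case Skip() => Connect(\"x\", Ref(\"b\"))\n",
    "    case other  => other\n",
    "  }\n",
    "\n",
    "  // Only the predicate is an Expression child of Conditionally\n",
    "  val exprMapped = c.mapExpr(fExp)\n",
    "  // Only conseq and alt are Statement children of Conditionally\n",
    "  val stmtMapped = c.mapStmt(fStmt)\n",
    "}\n",
    "```\n",
    "\n",
    "Because `mapExpr` only visits direct children, the `Ref(\"a\")` nested inside the `Connect` is left untouched; only the predicate becomes `Ref(\"en_renamed\")`.\n",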
    "\n",
    "Scala has \"infix notation\", which allows you to drop the `.` and parentheses when calling a method that takes one argument. We often write these `map` calls with infix notation:\n",
    "```scala\n",
    "c map fExp  // equivalent to c.map(fExp)\n",
    "c map fStmt // equivalent to c.map(fStmt)\n",
    "```\n",
    "\n",
    "### Pre-order traversal\n",
    "\n",
    "To traverse a Firrtl tree, we use `map` to write recursive functions which visit every child of every node we care about.\n",
    "\n",
    "Suppose we want to collect the names of every register declared in the design; we know this requires visiting every `Statement`. However, some `Statement` nodes have `Statement` children. Thus, we need a function that checks whether its input argument is a `DefRegister` and, if not, recursively applies itself to all `Statement` children of that argument.\n",
    "\n",
    "The following function, `f`, does exactly this, but it takes two parameter lists: a mutable `HashSet` of register names, and a `Statement`. Because `f` is curried, we can pass only the first argument to get a new function with the desired type signature (`Statement => Statement`):\n",
    "\n",
    "```scala\n",
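    "import scala.collection.mutable\n",
    "import firrtl.ir._\n",
    "import firrtl.Mappers._ // brings the `map` syntax for IR nodes into scope\n",
    "\n",
    "def f(regNames: mutable.HashSet[String])(s: Statement): Statement = s match {\n",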
    "  // If register, add name to regNames\n",
    "  case r: DefRegister =>\n",
    "    regNames += r.name\n",
    "    r // Return argument unchanged (ok because DefRegister has no Statement children)\n",
    "  // If not, apply f(regNames) to all children Statement\n",
    "  case _ => s map f(regNames) // Note that f(regNames) is of type Statement=>Statement\n",
    "}\n",
    "```\n",
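    "\n",
    "Here is a self-contained sketch of the same pre-order pattern on a toy statement IR, runnable without Firrtl. `DefRegister`, `Block`, `Conditionally`, and `EmptyStmt` are hypothetical simplifications of the Firrtl nodes, and `mapStmt` plays the role of `map` for `Statement` children.\n",
    "\n",
    "```scala\n",
    "import scala.collection.mutable\n",
    "\n",
    "// Toy model only: hypothetical, simplified classes, not Firrtl's real IR\n",
    "sealed trait Statement {\n",
    "  // Apply f to every direct Statement child of this node\n",
    "  def mapStmt(f: Statement => Statement): Statement\n",
    "}\n",
    "case class DefRegister(name: String) extends Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement = this // no Statement children\n",
    "}\n",
    "case object EmptyStmt extends Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement = this\n",
    "}\n",
    "case class Block(stmts: Seq[Statement]) extends Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement = Block(stmts.map(f))\n",
    "}\n",
    "case class Conditionally(pred: String, conseq: Statement, alt: Statement) extends Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement = Conditionally(pred, f(conseq), f(alt))\n",
    "}\n",
    "\n",
    "object PreOrderDemo {\n",
    "  // Pre-order: match on s first, then recurse into its Statement children\n",
    "  def f(regNames: mutable.HashSet[String])(s: Statement): Statement = s match {\n",
    "    case r: DefRegister =>\n",
    "      regNames += r.name\n",
    "      r\n",
    "    case _ => s.mapStmt(f(regNames))\n",
    "  }\n",
    "\n",
    "  val body = Block(Seq(\n",
    "    DefRegister(\"r1\"),\n",
    "    Conditionally(\"en\", Block(Seq(DefRegister(\"r2\"))), EmptyStmt)\n",
    "  ))\n",
    "\n",
    "  val collected: Set[String] = {\n",
    "    val acc = mutable.HashSet[String]()\n",
    "    f(acc)(body)\n",
    "    acc.toSet\n",
    "  }\n",
    "  // collected == Set(\"r1\", \"r2\")\n",
    "}\n",
    "```\n",
    "\n",
    "The register `r2` is found even though it is nested two levels deep, because every non-register node hands `f(regNames)` down to its own `Statement` children.\n",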
    "\n",
    "This pattern is very common in Firrtl and is called \"pre-order traversal\" because the recursive function matches on the original IR node before recursing into its child nodes.\n",
    "\n",
    "### Post-order traversal\n",
    "\n",
    "We can write the previous example in a \"post-order traversal\" as follows:\n",
    "\n",
    "```scala\n",
    "def f(regNames: mutable.HashSet[String])(s: Statement): Statement =\n",
    "  // Note that we immediately recurse into the children, then match on the result\n",
    "  s map f(regNames) match {\n",
    "    // If register, add name to regNames\n",
    "    case r: DefRegister =>\n",
    "      regNames += r.name\n",
    "      r // Return argument unchanged (ok because DefRegister has no Statement children)\n",
    "    // If not, return the node whose Statement children have already been processed\n",
    "    case other => other\n",
    "  }\n",
    "```\n",
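    "\n",
    "To see where the two orders actually differ, the sketch below records the order in which nodes are visited instead of collecting register names. It uses a hypothetical toy IR in which `Block` carries a `label` only so the trace is readable. The pre-order version records each node before recursing; the post-order version records it after its children have been processed.\n",
    "\n",
    "```scala\n",
    "import scala.collection.mutable\n",
    "\n",
    "// Toy model only: hypothetical classes, not Firrtl's real IR\n",
    "sealed trait Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement\n",
    "}\n",
    "case class DefRegister(name: String) extends Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement = this\n",
    "}\n",
    "case class Block(label: String, stmts: Seq[Statement]) extends Statement {\n",
    "  def mapStmt(f: Statement => Statement): Statement = Block(label, stmts.map(f))\n",
    "}\n",
    "\n",
    "object TraversalOrderDemo {\n",
    "  def nameOf(s: Statement): String = s match {\n",
    "    case DefRegister(n) => n\n",
    "    case Block(l, _)    => l\n",
    "  }\n",
    "\n",
    "  // Pre-order: record the node, then recurse into its children\n",
    "  def pre(trace: mutable.ListBuffer[String])(s: Statement): Statement = {\n",
    "    trace += nameOf(s)\n",
    "    s.mapStmt(pre(trace))\n",
    "  }\n",
    "\n",
    "  // Post-order: recurse into the children first, then record the node\n",
    "  def post(trace: mutable.ListBuffer[String])(s: Statement): Statement = {\n",
    "    val mapped = s.mapStmt(post(trace))\n",
    "    trace += nameOf(mapped)\n",
    "    mapped\n",
    "  }\n",
    "\n",
    "  val tree = Block(\"top\", Seq(DefRegister(\"r1\"), Block(\"inner\", Seq(DefRegister(\"r2\")))))\n",
    "\n",
    "  val preOrder: List[String] = {\n",
    "    val trace = mutable.ListBuffer[String]()\n",
    "    pre(trace)(tree)\n",
    "    trace.toList\n",
    "  }\n",
    "  val postOrder: List[String] = {\n",
    "    val trace = mutable.ListBuffer[String]()\n",
    "    post(trace)(tree)\n",
    "    trace.toList\n",
    "  }\n",
    "  // preOrder  == List(\"top\", \"r1\", \"inner\", \"r2\")\n",
    "  // postOrder == List(\"r1\", \"r2\", \"inner\", \"top\")\n",
    "}\n",
    "```\n",
    "\n",
    "Both traces list `r1` before `r2`, which is why register collection gives the same result either way, while the `Block` nodes move from before their children to after them.\n",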
    "\n",
    "While the traversal order differs between these two examples, it makes no difference for this use case (or for many others). However, it is an important tool to keep in your back pocket for when the order does matter, for example when processing a node requires its children to have been processed first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Scala",
   "language": "scala",
   "name": "scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala211",
   "nbconvert_exporter": "script",
   "pygments_lexer": "scala",
   "version": "2.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
